Add example of sparsely observed SIR model #2457

Merged
merged 2 commits into from
May 1, 2020

Conversation

Member

@fritzo fritzo commented May 1, 2020

Addresses #2426

This adds an example model with sparsely observed cumulative infections. This model is interesting because it preserves Markov structure by adding an auxiliary variable for fully observed cumulative infections (as discussed with @eb8680).

@eb8680 two notes:

  1. Regarding our models with Lucy, I think there is enough complexity that we should probably create an example in, say, examples/contrib/epidemiology/epi_phy.py that has a complete complex model and script. That is, I think it would be most educational to illustrate each model feature independently in the pyro.contrib.epidemiology module, and then combine them in examples/.
  2. The S2O flow looks a little like a chemical reaction with multiple products.

Tested

  • added unit tests

pyro/contrib/epidemiology/sir.py
mask_t = self.mask[t] if t < self.duration else False
data_t = self.data[t] if t < self.duration else None
pyro.sample("obs_{}".format(t),
dist.Delta(state["O"]).mask(mask_t),
Collaborator

why is the auxiliary necessary? or does this just play better with the structure of CompartmentalModel?

Member Author

@fritzo fritzo May 1, 2020

tl;dr The auxiliary variable is needed to preserve Markov structure.

The observations in this model are aggregated over intervals: obs=S2I[t_prev+1:t_curr+1].sum(), where t_prev is the time of the last observation and t_curr is the time of the current observation. In our enumeration strategy, this would couple all t_curr-t_prev-many enumeration variables, with cost growing exponentially in the number of coupled variables. While the non-parallel-scan enumeration strategy could handle this without erroring, it would be prohibitively expensive and would not allow large gaps in sensor data (e.g. when a government shuts down or runs out of tests). The trick we're using is to add an auxiliary variable for the entire cumulative observation trajectory (with the same likelihood as in the usual SIR models), and then Delta-clamp that auxiliary to the true observations at a few sparse time steps. This makes more work for HMC, adds one enumeration variable per time step, and increases the complexity of variable elimination by a constant factor of Q**2, but crucially this factor is independent of gap size.
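The bookkeeping behind this trick can be sketched in plain Python. This is only an illustration, not the code in sir.py: the flow values, the observation times, and the helper `interval_sums` are all hypothetical, and the observation likelihood is ignored so that the auxiliary is simply the running total of the flow. It shows that clamping a cumulative trajectory at sparse time steps carries the same information as the interval-aggregated observations.

```python
from itertools import accumulate

# Hypothetical new infections per time step (the S2I flow); illustrative values only.
s2i = [1, 0, 2, 3, 1, 0, 4, 2]
# Hypothetical sparse observation times, with a reporting gap after t=3.
obs_times = [2, 3, 7]


def interval_sums(flow, times):
    """Aggregate the flow over (t_prev, t_curr] for each observation time."""
    sums = []
    t_prev = -1
    for t_curr in times:
        sums.append(sum(flow[t_prev + 1:t_curr + 1]))
        t_prev = t_curr
    return sums


# Auxiliary cumulative trajectory, defined at every time step.
cumulative = list(accumulate(s2i))
# Mask marking the time steps at which the auxiliary is clamped to data.
mask = [t in obs_times for t in range(len(s2i))]
clamped = [c for c, m in zip(cumulative, mask) if m]

# Differences of successive clamped values recover the interval aggregates.
diffs = [b - a for a, b in zip([0] + clamped[:-1], clamped)]
assert diffs == interval_sums(s2i, obs_times)
```

Each entry of `cumulative` depends only on the previous entry and the current flow, so the trajectory stays Markov no matter how long the gap between observations is.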

I had been struggling with this issue for a while, since Lucy's model simulates 4 times per day but is observed only once per day. The only alternative I could see was to do parallel-scan variable elimination where each DiscreteHMM state covered the joint distribution over an entire day (four time steps), resulting in complexity Q**(2 * 4 * 2) for an SIR model or Q**(3 * 4 * 2) for an SEIR model.
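To get a feel for the sizes involved, the exponents quoted above can be evaluated directly. This is a rough arithmetic sketch only: the function names are hypothetical, Q=4 is an assumed illustrative value, and reading the exponent as (enumerated compartments) * (time steps per observation) * 2 is an interpretation of the formulas in the comment rather than something taken from the implementation.

```python
def joint_state_cost(num_quantiles, num_compartments, steps_per_obs):
    """Q ** (compartments * steps * 2), the exponent form quoted in the comment."""
    return num_quantiles ** (num_compartments * steps_per_obs * 2)


def auxiliary_overhead(num_quantiles):
    """Constant factor Q ** 2 added by the auxiliary variable."""
    return num_quantiles ** 2


Q = 4  # assumed value, for illustration only
sir_cost = joint_state_cost(Q, num_compartments=2, steps_per_obs=4)   # Q ** 16
seir_cost = joint_state_cost(Q, num_compartments=3, steps_per_obs=4)  # Q ** 24
aux_factor = auxiliary_overhead(Q)                                    # Q ** 2

print(sir_cost, seir_cost, aux_factor)
```

The joint-state cost grows with the number of steps per observation, while the auxiliary-variable factor does not depend on it.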

Collaborator

thanks for the explanation. to clarify though: if all you had was occasional missing data you wouldn't need this construction. this is really for the cumulative case

Member Author

Correct. It appears the cumulative case is more common in epidemiology.

@martinjankowiak martinjankowiak merged commit 5c7f1c3 into dev May 1, 2020